Feat: update eagle3 example; add export #293
Conversation
Walkthrough

Reworks the speculative-decoding example into a ModelOpt / Hugging Face–centric end-to-end pipeline: README overhaul, new training/export scripts and a SLURM data guide, CLI/calibration/eagle_config changes, an HF export plugin with unified-export integration, Eagle forward/training logic extensions, and minor telemetry/server tweaks.
Sequence Diagram(s)

```mermaid
sequenceDiagram
autonumber
participant User
participant TrainScript as train_eagle3_and_export.sh
participant Launch as launch_train.sh
participant Trainer as Trainer/ModelOpt
participant Validator as ar_validate.py
participant Exporter as export_hf_checkpoint.py
participant Unified as unified_export_hf.py
participant Plugin as hf_spec_export.py
User->>TrainScript: run (base_model, num_gpu, data)
TrainScript->>Launch: ./launch_train.sh (MODE=eagle3 ...)
Launch->>Trainer: start training (Eagle architecture)
Trainer-->>TrainScript: checkpoint → ckpts/<model>-<ts>
TrainScript->>Validator: python ar_validate.py --model_path <ckpt>
Validator-->>TrainScript: AR metrics (optionally → wandb)
TrainScript->>Exporter: python export_hf_checkpoint.py --model_path <ckpt> --export_path <export>
Exporter->>Unified: export_hf_checkpoint(model, export_dir)
Unified->>Plugin: rename_and_prune_if_spec_decoding(model, state_dict)
Unified->>Plugin: set_config_if_spec_decoding(model, config)
Unified-->>Exporter: saved model/config (+quant cfg)
Exporter-->>User: "Exported checkpoint to <export>"
```

```mermaid
sequenceDiagram
autonumber
participant HFModel as HFEagleModel
participant Base as Base LLM
participant Eagle as Eagle Module
Note over HFModel: Forward supports optional base_model_outputs and draft-vocab mapping
HFModel->>Base: optionally call forward / or accept provided base_model_outputs
Base-->>HFModel: hidden_states / logits (or omitted)
HFModel->>HFModel: _map_logits_to_draft_vocab(full_logits) -- when draft_vocab != vocab
HFModel->>Eagle: multi-step eagle forward (steps 0..3) with aux states
Eagle-->>HFModel: eagle_logits per step
HFModel->>HFModel: _eagle_loss -> (reg_loss, cls_loss, accuracy_k)
HFModel-->>Caller: ModelOutput(..., train_acc=(acc0..acc3), losses...)
```
Estimated code review effort: 🎯 5 (Critical) | ⏱️ ~120 minutes
Pre-merge checks (1 passed, 1 warning, 1 inconclusive)
- ❌ Failed checks (1 warning, 1 inconclusive)
- ✅ Passed checks (1 passed)
e6334c4 → 47a0a50 (compare)
Codecov Report

✅ All modified and coverable lines are covered by tests.

Additional details and impacted files

```
@@ Coverage Diff @@
## main #293 +/- ##
=======================================
Coverage 73.93% 73.93%
=======================================
Files 172 172
Lines 17408 17408
=======================================
Hits 12871 12871
Misses 4537 4537
```

☔ View full report in Codecov by Sentry.
Actionable comments posted: 6
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/server_generate.py (1)
156-158: Fix exception handler referencing undefined `prompt` in chat mode. In the chat path, `prompt` is never defined; on errors this raises a new exception and hides the real one. Apply one of these minimal fixes:

```diff
-        except Exception as e:
-            print(e)
-            print(prompt)
-            print("Failed to generate data")
+        except Exception as e:
+            print(e)
+            if "prompt" in locals():
+                print(prompt)
+            else:
+                print("prompt not set (chat mode)")
+            print("Failed to generate data")
```

Or define `prompt = None` at function start and print it only if it is not `None`.
🧹 Nitpick comments (30)
modelopt/torch/export/plugins/hf_spec_export.py (4)
16-16: Fix typos in module docstring. "Modifiy stated_dict" → "Modify state_dict".

```diff
-"""Modifiy stated_dict and config for exporting speculative decoding in official format."""
+"""Modify state_dict and config for exporting speculative decoding in official format."""
```
23-24: Remove or use unused SPECULATIVE_DECODING_MODES. Currently unused; either hook it into the guards or drop it to avoid dead code.

```diff
-SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]
+# Reserved for future multi-mode handling:
+# SPECULATIVE_DECODING_MODES = ["eagle", "medusa"]
```
25-44: Rename constant to fix typo and improve clarity. EALGE_MODELOPT_TO_OFFICIAL → EAGLE_MODELOPT_TO_OFFICIAL; update references.

```diff
-EALGE_MODELOPT_TO_OFFICIAL = {
+EAGLE_MODELOPT_TO_OFFICIAL = {
@@
-    _check_state_dict_keys_match(model.eagle_module, EALGE_MODELOPT_TO_OFFICIAL["required"])
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
@@
-        **EALGE_MODELOPT_TO_OFFICIAL["required"],
-        **EALGE_MODELOPT_TO_OFFICIAL["optional"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
```

Also applies to: 65-66, 70-72
68-79: Avoid repeated state_dict() calls and use the provided post_state_dict for fallback. Slight perf/clarity win and safer fallback when lm_head lives only in the passed-in state dict.

```diff
-    export_state_dict = {}
-    for ours_key, export_key in {
-        **EALGE_MODELOPT_TO_OFFICIAL["required"],
-        **EALGE_MODELOPT_TO_OFFICIAL["optional"],
-    }.items():
-        if ours_key in model.eagle_module.state_dict():
-            export_state_dict[export_key] = model.eagle_module.state_dict()[ours_key]
+    module_state = model.eagle_module.state_dict()
+    export_state_dict = {}
+    for ours_key, export_key in {
+        **EAGLE_MODELOPT_TO_OFFICIAL["required"],
+        **EAGLE_MODELOPT_TO_OFFICIAL["optional"],
+    }.items():
+        if ours_key in module_state:
+            export_state_dict[export_key] = module_state[ours_key]
@@
-    if "eagle_lm_head.weight" not in model.eagle_module.state_dict():
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    if "eagle_lm_head.weight" not in module_state:
+        if "lm_head.weight" in post_state_dict:
+            export_state_dict["lm_head.weight"] = post_state_dict["lm_head.weight"]
+        else:
+            # Fall back to model.state_dict() if needed
+            base_state = model.state_dict()
+            if "lm_head.weight" in base_state:
+                export_state_dict["lm_head.weight"] = base_state["lm_head.weight"]
+            else:
+                raise KeyError("lm_head.weight not found in post_state_dict or model.state_dict()")
```

modelopt/torch/speculative/plugins/transformers.py (4)
721-727: Align reverse_mapping device with logits to avoid cross-device indexing. Safer if full_logits is on a different device than d2t.

```diff
-        reverse_mapping = (
-            torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device)
-            + self.eagle_module.d2t
-        )
+        reverse_mapping = (
+            torch.arange(len(self.eagle_module.d2t), device=full_logits.device, dtype=torch.long)
+            + self.eagle_module.d2t.to(full_logits.device)
+        )
         return full_logits[:, :, reverse_mapping]
```
856-858: Nit: fix comment typo. "diabled" → "disabled".

```diff
-        # NOTE: diabled for now.
+        # NOTE: disabled for now.
```
1052-1058: Compute masked accuracy over valid positions only. The current mean divides by B*T; use sum over mask / mask.sum() for accurate reporting when masks zero out prefixes.

```diff
-            base_predict_tok = base_model_logits.argmax(dim=-1)
-            eagle_predict_tok = eagle_logits.argmax(dim=-1)
-            accuracy = (
-                (loss_mask[:, :, 0] * (base_predict_tok == eagle_predict_tok)).float().mean().item()
-            )
-            accuracy = round(accuracy, 3)
+            base_predict_tok = base_model_logits.argmax(dim=-1)
+            eagle_predict_tok = eagle_logits.argmax(dim=-1)
+            valid = loss_mask[:, :, 0].bool()
+            correct = (base_predict_tok == eagle_predict_tok) & valid
+            denom = valid.sum().clamp_min(1).float()
+            accuracy = round(correct.sum().float().div(denom).item(), 3)
```
716-717: Add explicit `d2t` buffer assertion before vocab remap. Although gating on `draft_vocab_size != vocab_size` is correct, insert an assertion (e.g. `assert hasattr(self, "d2t"), "d2t buffer not initialized"`) immediately before calling `_map_logits_to_draft_vocab` to surface misconfigurations more clearly.

modelopt/torch/export/plugins/__init__.py (1)
18-24: Optional: make public surface explicit. Consider defining `__all__` in hf_spec_export to avoid star-import drift.

```diff
 with import_plugin("transformers"):
-    from .hf_spec_export import *
+    from .hf_spec_export import *  # relies on hf_spec_export.__all__
```

examples/speculative_decoding/export_hf_checkpoint.py (2)
25-29: Make CLI args required and self-documenting. Avoid silent defaults; require both paths and add help for better UX.

```diff
-def parse_args():
-    parser = argparse.ArgumentParser()
-    parser.add_argument("--model_path", type=str, default="")
-    parser.add_argument("--export_path", type=str, default="")
-    return parser.parse_args()
+def parse_args():
+    parser = argparse.ArgumentParser(description="Export a HF checkpoint (with ModelOpt state) for deployment.")
+    parser.add_argument("--model_path", type=str, required=True, help="Path or HF hub id of the trained checkpoint.")
+    parser.add_argument("--export_path", type=str, required=True, help="Destination directory for exported files.")
+    return parser.parse_args()
```
34-41: Set eval mode prior to export. Ensures nondeterministic modules (e.g., dropout) are disabled during the dummy forward used in export.

```diff
 args = parse_args()
-model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model = AutoModelForCausalLM.from_pretrained(args.model_path, torch_dtype="auto")
+model.eval()
 with torch.inference_mode():
     export_hf_checkpoint(
         model,  # The quantized model.
         export_dir=args.export_path,  # The directory where the exported files will be stored.
     )
```

modelopt/torch/export/unified_export_hf.py (2)
497-503: Make `_quant_applied` resilient to missing keys. Minor polish: use `dict.get` to avoid KeyError if the structure changes.

```diff
-def _quant_applied(hf_quant_config: dict) -> bool:
-    """Check if any quantization is applied."""
-    return not (
-        hf_quant_config["quantization"]["quant_algo"] == QUANTIZATION_NONE
-        and not hf_quant_config["quantization"]["quantized_layers"]
-    )
+def _quant_applied(hf_quant_config: dict) -> bool:
+    """Check if any quantization is applied."""
+    q = hf_quant_config.get("quantization", {})
+    return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))
```
30-33: Optional: avoid hard import if plugins are unused. If plugins are only relevant for Eagle exports, consider importing inside `export_hf_checkpoint` to reduce import-time side effects and avoid potential cycles.
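A minimal sketch of this deferred-import pattern; the module path and helper names come from this PR's changed files, while the function body is illustrative only:

```python
def export_hf_checkpoint(model, export_dir: str) -> None:
    # Import the spec-decoding helpers lazily so plain HF/quantization exports
    # never pay the transformers-plugin import cost or risk circular imports.
    from modelopt.torch.export.plugins.hf_spec_export import (
        rename_and_prune_if_spec_decoding,
        set_config_if_spec_decoding,
    )

    post_state_dict = rename_and_prune_if_spec_decoding(model, model.state_dict())
    config_data = set_config_if_spec_decoding(model, model.config.to_dict())
    # ... write post_state_dict / config_data to export_dir as the existing flow does ...
```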
examples/speculative_decoding/train_eagle3_and_export.sh (4)
53-53: Fix ShellCheck SC2155: avoid command substitution in export. Assign first, then export to prevent masking return values.

```diff
-export CUDA_VISIBLE_DEVICES=$(seq -s, 0 $((NUM_GPU-1)))
+devs="$(seq -s, 0 $((NUM_GPU-1)))"
+export CUDA_VISIBLE_DEVICES="$devs"
```
58-66: Create output dirs proactively. Avoid relying on downstream scripts to mkdir. Harmless if they already exist.

```diff
 echo "==== [1/3] Training draft model ===="
 OUTPUT_DIR=ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$OUTPUT_DIR")"
 ./launch_train.sh --model $BASE_MODEL \
     --output_dir $OUTPUT_DIR \
     --data $DATA \
     --num_gpu $NUM_GPU \
     --num_epochs 2 \
     --eagle_config eagle_config.json
```
70-72: Also mkdir for export path. Prevents failures if the parent dir is missing.

```diff
 echo "==== [3/3] Exporting checkpoint to deployment format ===="
 EXPORT_PATH=export/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)
+mkdir -p "$(dirname "$EXPORT_PATH")"
 python export_hf_checkpoint.py --model_path $OUTPUT_DIR --export_path $EXPORT_PATH
```
25-41: Minor: align flag name in comment with implementation. The comment says `--base-model` but the code uses `--base_model`.

examples/speculative_decoding/README.md (11)
7-9: Fix grammar and proper names in intro. Use articles and brand casing; current phrasing is awkward.

```diff
-This folder contains end-to-end runnable speculative decoding fine-tuning pipeline where Llama3.2-1B from huggingface is trained on Daring-Anteater dataset.
-
-This example focus on training with HF. To train with Megatron-LM, please refer to [this link](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt) in Megatron-LM repo.
+This folder contains an end-to-end runnable speculative decoding fine‑tuning pipeline in which Llama‑3.2‑1B (Hugging Face) is trained on the Daring‑Anteater dataset.
+
+This example focuses on training with Hugging Face. To train with Megatron‑LM, see the [Megatron‑LM example](https://github.com/NVIDIA/Megatron-LM/tree/main/examples/post_training/modelopt).
```
65-71: Avoid suggesting hardcoded API keys in examples. Remove the fake key or clearly denote an env var. This reduces the chance of users pasting secrets.

```diff
-pip install vllm
-vllm serve meta-llama/Llama-3.2-1B-Instruct --api-key token-abc123 --port 8000 --tensor-parallel-size 1
+pip install vllm
+# If your deployment requires an API key, set VLLM_API_KEY in the environment instead of hardcoding.
+VLLM_API_KEY=... vllm serve meta-llama/Llama-3.2-1B-Instruct --port 8000 --tensor-parallel-size 1
```
90-96: Tighten wording and fix minor grammar.

```diff
-For eagle1 and eagle3 we provide an [default model architecture config](.../default_config.py#L18) in modelopt. User can overwrite default settings by providing additional json dict. In this example, we overwrite the `draft_vocab_size` by in `eagle_config.json`:
+For EAGLE‑1 and EAGLE‑3 we provide a [default model architecture config](.../default_config.py#L18) in ModelOpt. You can override default settings by providing an additional JSON dict. In this example, we override `draft_vocab_size` in `eagle_config.json`:
```
100-108: Fix typos ("tokenzier", "hugginface") and add clarity.

```diff
-`main.py` provides a example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
-First, load base model and tokenzier from hugginface:
+`main.py` provides an example for converting a base HF model for speculative decoding and training it. It consists of a few simple steps:
+First, load the base model and tokenizer from Hugging Face:
```
118-121: Typo in key name placeholder.

```diff
-# overwrite config with custom config
-config["eagle_architecture_config"].update({"<overwrite_kyes>": "<overwrite_values>"})
+# Overwrite config with custom config
+config["eagle_architecture_config"].update({"<overwrite_keys>": "<overwrite_values>"})
```
131-136: Fix typo and verify API alias. Spelling: "deocoding" → "decoding". Also, please confirm `mtsp.convert` is the correct symbol/alias in this repo.

```diff
-Then, we convert model to a speculative deocoding model:
+Then, we convert the model to a speculative decoding model:
```
140-147: Avoid private Trainer APIs; rely on Trainer for device placement. `trainer._move_model_to_device` is a private method and may change; Trainer already handles device placement. Also confirm the correct path for `enable_huggingface_checkpointing()`.

```diff
-trainer._move_model_to_device(model, trainer.args.device)
-
 # Enable HF checkpointing so that the saved model will contain the speculative decoding module
-mto.enable_huggingface_checkpointing()
+from modelopt.torch import export as mto_export  # adjust import as needed
+mto_export.enable_huggingface_checkpointing()
```
152-161: Fix spelling and tighten phrasing in launch instructions.

```diff
-... along with a bash script to launch the training with huggingface accelrate in `launch_train.sh`, which can be runned by:
+... along with a bash script to launch training with Hugging Face Accelerate in `launch_train.sh`, which can be run by:
```
189-197: Minor YAML lead‑in punctuation and naming consistency.

```diff
-To serve the checkpoint with trtllm, we can run trtllm-serve with:
+To serve the checkpoint with TRT‑LLM, run `trtllm-serve`:
 ...
-,
-with `extra-llm-api-config.yml` being
+with `extra-llm-api-config.yml`:
```
224-233: Brand/style fixes in Support Matrix.

```diff
-| LLAMA 2 | ✅ | ✅ | ✅ |
-| LLAMA 3, 3.1 | ✅ | ✅ | ✅ |
+| Llama 2 | ✅ | ✅ | ✅ |
+| Llama 3, 3.1 | ✅ | ✅ | ✅ |
 ...
-| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ |
+| Qwen 1.5, 2, 2.5 | ✅ | ✅ | ✅ |
```
236-238: Use correct company casing ("NVIDIA").

```diff
-Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - Nvidia TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
+Ready-to-deploy speculation module checkpoints \[[🤗 Hugging Face - NVIDIA TensorRT Model Optimizer Collection](https://huggingface.co/collections/nvidia/model-optimizer-66aa84f7966b3150262481a4)\]
```

examples/speculative_decoding/eagle_config.json (1)
2-8: Align rope_scaling schema with expected HF conventions (consider adding `"type": "dynamic"`). Given the presence of `low_freq_factor`, `high_freq_factor`, and `original_max_position_embeddings`, many HF configs expect `rope_scaling.type="dynamic"`. If the exporter relies on HF-compatible fields, add the type or confirm your plugin translates this correctly. Apply this minimal diff if compatible with your pipeline:

```diff
   "rope_scaling": {
+    "type": "dynamic",
     "factor": 32.0,
     "low_freq_factor": 1.0,
     "high_freq_factor": 4.0,
     "original_max_position_embeddings": 8192,
     "rope_type": "llama3"
   },
```
examples/speculative_decoding/server_generate.py (1)
56-56: Make `--system_prompt` truly optional to avoid odd joining behavior. With `nargs="+"` and default `""`, `" ".join(args.system_prompt)` can behave unexpectedly. Prefer a list default. Apply:

```diff
-parser.add_argument("--system_prompt", nargs="+", type=str, default="", help="System prompt")
+parser.add_argument("--system_prompt", nargs="*", type=str, default=[], help="System prompt")
```

`system_prompt = " ".join(args.system_prompt)` will remain correct (empty string when not provided).

Also applies to: 185-186
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (13)
- examples/speculative_decoding/README.md (2 hunks)
- examples/speculative_decoding/ar_validate.py (2 hunks)
- examples/speculative_decoding/calibrate_draft_vocab.py (2 hunks)
- examples/speculative_decoding/eagle_config.json (1 hunks)
- examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
- examples/speculative_decoding/launch.sh (0 hunks)
- examples/speculative_decoding/main.py (3 hunks)
- examples/speculative_decoding/server_generate.py (1 hunks)
- examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
- modelopt/torch/export/plugins/__init__.py (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (4 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
💤 Files with no reviewable changes (1)
- examples/speculative_decoding/launch.sh
🧰 Additional context used
🧬 Code graph analysis (6)
examples/speculative_decoding/export_hf_checkpoint.py (2)
- modelopt/torch/export/unified_export_hf.py (1): export_hf_checkpoint (505-557)
- modelopt/torch/opt/plugins/huggingface.py (1): enable_huggingface_checkpointing (127-162)

examples/speculative_decoding/calibrate_draft_vocab.py (1)
- modelopt/torch/speculative/utils.py (1): calibrate_frequent_vocab (31-45)

modelopt/torch/export/plugins/hf_spec_export.py (1)
- modelopt/torch/speculative/plugins/transformers.py (1): HFEagleModel (333-1138)

modelopt/torch/export/unified_export_hf.py (1)
- modelopt/torch/export/plugins/hf_spec_export.py (2): rename_and_prune_if_spec_decoding (55-80), set_config_if_spec_decoding (83-151)

examples/speculative_decoding/main.py (1)
- modelopt/torch/export/model_config.py (1): max_position_embeddings (603-605)

modelopt/torch/speculative/plugins/transformers.py (1)
- modelopt/torch/quantization/qtensor/base_qtensor.py (1): to (115-123)
🪛 Shellcheck (0.10.0)
examples/speculative_decoding/train_eagle3_and_export.sh
[warning] 53-53: Declare and assign separately to avoid masking return values.
(SC2155)
🪛 LanguageTool
examples/speculative_decoding/README.md
[grammar] ~7-~7: There might be a mistake here.
Context: ...ntly improving throughput. This folder contains end-to-end runnable speculative decodin...
(QB_NEW_EN)
[grammar] ~7-~7: There might be a mistake here.
Context: ...Llama3.2-1B from huggingface is trained on Daring-Anteater dataset. This example ...
(QB_NEW_EN)
[grammar] ~9-~9: There might be a mistake here.
Context: ...e/main/examples/post_training/modelopt) in Megatron-LM repo. ## Contents <div al...
(QB_NEW_EN)
[grammar] ~15-~15: There might be a mistake here.
Context: ...tion** | Description | Jump To | | :------------: | :------------: | :---...
(QB_NEW_EN)
[grammar] ~16-~16: There might be a mistake here.
Context: ...---: | :------------: | :------------: | | Pre-Requisites | Required & optional d...
(QB_NEW_EN)
[grammar] ~17-~17: There might be a mistake here.
Context: ...ndencies | [Link] | | Simplified Workflow | Train, evaluate ...
(QB_NEW_EN)
[grammar] ~18-~18: There might be a mistake here.
Context: ...getting-started-simplified-workflow)] | | Complete Workflow | Full example with ...
(QB_NEW_EN)
[grammar] ~19-~19: Ensure spelling is correct
Context: ...rkflow | Full example with configurable traininig pipeline | [Link] ...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
[grammar] ~19-~19: There might be a mistake here.
Context: ...pipeline | [Link] | | Support Matrix | Supported models for ...
(QB_NEW_EN)
[grammar] ~20-~20: There might be a mistake here.
Context: ...ing training | [Link] | | Speculation Module Checkpoints | View ...
(QB_NEW_EN)
[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...
(QB_NEW_EN)
[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...
(QB_NEW_EN)
[grammar] ~72-~72: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...
(QB_NEW_EN)
[grammar] ~80-~80: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...
(QB_NEW_EN)
[grammar] ~90-~90: There might be a mistake here.
Context: ... User can overwrite default settings by providing additional json dict. In this example, ...
(QB_NEW_EN)
[grammar] ~100-~100: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load base model and tokenzier fro...
(QB_NEW_EN)
[grammar] ~101-~101: There might be a mistake here.
Context: ... consists of a few simple steps: First, load base model and tokenzier from hugginfac...
(QB_NEW_EN)
[grammar] ~101-~101: Ensure spelling is correct
Context: ...imple steps: First, load base model and tokenzier from hugginface: ```python model = tra...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
[grammar] ~109-~109: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...
(QB_NEW_EN)
[grammar] ~109-~109: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...
(QB_NEW_EN)
[grammar] ~131-~131: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative deocoding model:...
(QB_NEW_EN)
[grammar] ~131-~131: Ensure spelling is correct
Context: ...Then, we convert model to a speculative deocoding model: ```python mtsp.convert(model, [...
(QB_NEW_EN_ORTHOGRAPHY_ERROR_IDS_1)
[grammar] ~137-~137: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...
(QB_NEW_EN)
[grammar] ~167-~167: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...
(QB_NEW_EN)
[grammar] ~226-~226: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...
(QB_NEW_EN)
[grammar] ~227-~227: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...
(QB_NEW_EN)
[grammar] ~228-~228: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...
(QB_NEW_EN)
[grammar] ~229-~229: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...
(QB_NEW_EN)
[grammar] ~230-~230: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...
(QB_NEW_EN)
[grammar] ~231-~231: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...
(QB_NEW_EN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (1)
- GitHub Check: linux
🔇 Additional comments (18)
examples/speculative_decoding/ar_validate.py (4)
29-29
: LGTM! Improved default values for AR validation.The increase in
num_samples
from 20 to 80 provides better statistical reliability for AR validation results.
57-57
: LGTM! More reasonable default for validation steps.Changing the default
steps
from 1 to 3 provides better validation coverage while remaining computationally reasonable.
59-59
: LGTM! Adjusted OSL for better performance balance.Reducing the default
osl
from 100 to 32 should improve validation speed while maintaining reasonable sequence length coverage.
62-62
: LGTM! Consistent with function signature update.The CLI default now matches the updated function signature default of 80 samples.
examples/speculative_decoding/calibrate_draft_vocab.py (2)
31-35
: LGTM! Better separation of concerns.Moving
draft_vocab_size
from the config file to a direct CLI parameter provides clearer interface separation and makes the parameter more explicit.
55-55
: LGTM! Consistent with the CLI parameter change.The function call now correctly uses the CLI argument instead of reading from the config file.
examples/speculative_decoding/main.py (3)
50-56
: LGTM! Proper optional wandb integration.The graceful handling of wandb import ensures the script continues to work when wandb is not available, while enabling enhanced logging when it is present.
180-181
: LGTM! Essential configuration propagation.Propagating
max_position_embeddings
from the base model to the Eagle architecture config ensures deployment compatibility and proper model configuration alignment.
225-226
: LGTM! Conditional wandb logging.The conditional logging ensures AR validation metrics are captured when wandb is available without causing errors when it's not installed.
modelopt/torch/speculative/plugins/transformers.py (5)
688-689
: Good change: pass through extra kwargs.Allows future-proofing without breaking call sites.
791-804
: Base-model outputs fast-path looks good.Nice flexibility to consume teacher-provided tensors and skip recompute.
1079-1080
: Good device alignment for base_token.Avoids device mismatch in downstream concat.
1156-1157
: Good: ensure device match in AR validation.Prevents accidental CPU/GPU mismatch during concatenation.
1021-1029
: Verify no positional unpacking of ModelOutput
Ensure no downstream code destructures or uses.to_tuple()
assuming the previous tuple shape, as addingtrain_acc
changes its length and order.modelopt/torch/export/plugins/__init__.py (1)
22-24
: Re-export looks fine under transformers plugin guard.Keeps import-time side effects contained to environments with transformers installed.
examples/speculative_decoding/README.md (1)
41-55
: Nice addition: one‑liner workflow with training + export.Clear and actionable; the defaults link helps users understand what’s applied.
examples/speculative_decoding/eagle_config.json (1)
9-9
: Verifyinitializer_range
matches the base model’s setting.A mismatch can affect initialization and training stability; ensure 0.02 is consistent with the chosen base model config.
examples/speculative_decoding/server_generate.py (1)
49-49: Use `argparse.BooleanOptionalAction` for `--chat` instead of `type=bool`. Replace in examples/speculative_decoding/server_generate.py (around line 49):

```diff
-parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
+parser.add_argument(
+    "--chat",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="Use chat mode",
+)
```

Confirm your minimum Python version is ≥ 3.9 so `BooleanOptionalAction` exists, and update any docs or scripts to reference `--no-chat` where needed.
Actionable comments posted: 1
🧹 Nitpick comments (4)
examples/speculative_decoding/launch_train.sh (4)
95-97: Sanitize OUTPUT_DIR model basename to avoid filesystem issues. Some model names contain characters problematic for paths. Sanitize the basename before composing `OUTPUT_DIR`.

```diff
 MODEL_BASENAME=$(basename "$MODEL")
-OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME}-$(date +%Y%m%d_%H%M)"}
+MODEL_BASENAME_SAFE="${MODEL_BASENAME//[^a-zA-Z0-9._-]/_}"
+OUTPUT_DIR=${OUTPUT_DIR:-"ckpts/${MODEL_BASENAME_SAFE}-$(date +%Y%m%d_%H%M)"}
```
77-83: Don't hard-fail on unknown flags; forward them to the training script. Exiting on the first unknown arg prevents users from passing through Hugging Face/Accelerate options. Collect and forward instead.

```diff
-    *)
-        >&2 printf "Error: Invalid argument ${1#*=}\n"
-        exit 1
-        ;;
+    *)
+        EXTRA_ARGS+=("$1")
+        ;;
```

Add initialization near the top (before the while-loop):

```diff
 set -eo pipefail
+
+EXTRA_ARGS=()
```
130-153: Plumb through extra args to CMD. Forward any collected unknown args to `main.py`.

```diff
     --data_path $DATA \
-    $SPECULATIVE_ARGS
+    $SPECULATIVE_ARGS \
+    ${EXTRA_ARGS[@]:-}
```
141-142: Make `--do_eval` overridable while defaulting to False. Hardcoding `--do_eval False` blocks standalone use; keep False by default but allow override via a parsed `--do_eval`.

```diff
-    --do_eval False \
+    --do_eval ${DO_EVAL:-False} \
```

Add a parser case (outside this hunk) to honor `--do_eval`:

```bash
# In the case/esac parser:
--do_eval*)
    if [[ "$1" != *=* ]]; then shift; fi
    DO_EVAL="${1#*=}"
    ;;
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- examples/speculative_decoding/launch_train.sh (4 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/speculative/plugins/transformers.py
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (2)
examples/speculative_decoding/launch_train.sh (2)
151-152: Validate required data path early. If `DATA` is empty, `--data_path` will be passed without a value and the run will fail late. Fail fast with a clear message or provide a sensible default.

```bash
# Add before building CMD:
if [[ -z "${DATA:-}" ]]; then
    >&2 echo "Error: --data <path> is required (set via --data=...)."
    exit 1
fi
```
128-129
: Good call disabling tokenizers parallelism.This avoids noisy warnings and thread oversubscription in tokenization-heavy phases.
> This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10\*max_lines_per_shard number of samples).
> In this case, the first 40 shards of data will be processed.
> To process the next 40 shards
> This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later.
Better to explain how this mapping works. i.e., tokens = draft_tokens + d2t[draft_tokens]
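A tiny self-contained sketch of that relation (the `d2t` name comes from the example; the sizes and offsets below are made up for illustration):

```python
import torch

# d2t[i] holds the offset from draft-vocab id i to its id in the full vocab,
# so full-vocab tokens are recovered as: tokens = draft_tokens + d2t[draft_tokens].
draft_vocab_size = 8
d2t = torch.tensor([0, 3, 6, 9, 12, 15, 18, 21])  # hypothetical offsets, len == draft_vocab_size

draft_tokens = torch.tensor([0, 2, 5, 7])
tokens = draft_tokens + d2t[draft_tokens]
print(tokens)  # tensor([ 0,  8, 20, 28])
```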
cf17899 → 56fafda (compare)
Actionable comments posted: 5
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/speculative_decoding/README.md (2)
137-149
: Definemtsp
/mto
imports before use.
Make the snippet self-contained.-```python -mtsp.convert(model, [("eagle", config)]) -``` -... -```python -# Enable HF checkpointing so that the saved model will contain the speculative decoding module -mto.enable_huggingface_checkpointing() -``` +```python +import modelopt.torch.speculative as mtsp +import modelopt.torch.export.unified_export_hf as mto + +mtsp.convert(model, [("eagle", config)]) +... +# Enable HF checkpointing so that the saved model will contain the speculative decoding module +mto.enable_huggingface_checkpointing() +```
143-151
: Avoid private Trainer API and undefined variable.
_move_model_to_device
is private; Trainer handles device placement.checkpoint
isn’t defined; omit unless documented.-trainer._move_model_to_device(model, trainer.args.device) - # Enable HF checkpointing so that the saved model will contain the speculative decoding module mto.enable_huggingface_checkpointing() -trainer.train(resume_from_checkpoint=checkpoint) +trainer.train() trainer.save_state()
♻️ Duplicate comments (4)
modelopt/torch/export/unified_export_hf.py (1)
529-545
: Guard spec-decoding transforms when ModelOpt state is absentCalling spec-decoding transforms unconditionally can break vanilla HF exports. Add a safe gate.
- post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict) + if hasattr(model, "_modelopt_state") and getattr(model, "_modelopt_state") is not None: + post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict) @@ - config_data = set_config_if_spec_decoding(model, config_data) + if hasattr(model, "_modelopt_state") and getattr(model, "_modelopt_state") is not None: + config_data = set_config_if_spec_decoding(model, config_data)examples/speculative_decoding/README.md (3)
19-20
: Broken anchors from prior commit are now correct.
“Complete Workflow” and “Support Matrix” Jump To targets are fixed.
172-174
: Code-fence language fixed.
CLI is now fenced as bash, not python.
181-185
: Export section fences and copy now correct.
Language tag and “format” typo are fixed.
🧹 Nitpick comments (16)
examples/speculative_decoding/SLURM_prepare_data.md (1)
9-10
: Quote job name; normalize spacingAvoid parsing issues with ':' in job name; trim double spaces.
-salloc -N4 -A <account> -p <partition> -J <account>-synthetic:data-gen -t 120 +salloc -N4 -A <account> -p <partition> -J "<account>-synthetic:data-gen" -t 120modelopt/torch/export/unified_export_hf.py (1)
517-526
: Path handling nit: prefer Path operations over f-stringsImproves readability and OS-compatibility.
- with open(f"{export_dir}/hf_quant_config.json", "w") as file: + with open(Path(export_dir) / "hf_quant_config.json", "w") as file: json.dump(hf_quant_config, file, indent=4) @@ - original_config = f"{export_dir}/config.json" + original_config = Path(export_dir) / "config.json" config_data = {} @@ - with open(original_config, "w") as file: + with open(original_config, "w") as file: json.dump(config_data, file, indent=4)Also applies to: 536-547
modelopt/torch/speculative/plugins/transformers.py (4)
689-717
: Consistent draft-vocab remap policy_base path remaps logits only during training; teacher-provided path remaps unconditionally. Ensure both paths align with how downstream loss expects logits to be in draft space.
- if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size and self.training: + if self.eagle_config.draft_vocab_size != self.eagle_config.vocab_size: assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" base_model_logits = self._map_logits_to_draft_vocab(base_model_logits)If inference truly needs full-vocab logits, gate at call site instead.
720-726
: Bounds/correctness checks and caching for mappingAdd sanity checks and cache the reverse index to avoid recomputing per call.
- def _map_logits_to_draft_vocab(self, full_logits): - reverse_mapping = ( - torch.arange(len(self.eagle_module.d2t)).to(self.eagle_module.d2t.device) - + self.eagle_module.d2t - ) - return full_logits[:, :, reverse_mapping] + def _map_logits_to_draft_vocab(self, full_logits): + d2t = self.eagle_module.d2t + draft = d2t.numel() + if not hasattr(self, "_reverse_draft_index") or self._reverse_draft_index.numel() != draft: + self._reverse_draft_index = torch.arange(draft, device=d2t.device) + d2t + # Assert indices are in range + assert self._reverse_draft_index.max().item() < full_logits.size(-1), "draft→full index OOB" + return full_logits.index_select(-1, self._reverse_draft_index)
826-831
: Aux hidden states: validate presence/shape in teacher pathIf
use_aux_hidden_state
is True and teacher path is used, assertaux_hidden_states
is provided and has expected last-dim = len(layer_ids)*hidden_size to fail fast with a clear message.- if "base_model_outputs" in kwargs: - aux_hidden_states = kwargs["base_model_outputs"]["aux_hidden_states"] + if "base_model_outputs" in kwargs: + aux_hidden_states = kwargs["base_model_outputs"].get("aux_hidden_states") + assert aux_hidden_states is not None, "aux_hidden_states required for EAGLE-3 teacher path"
790-803: Clarify DynamicCache initialization and document the aux_hidden_states requirement. The `past_key_values = None` branch is already handled by `DynamicCache.from_legacy_cache(None)`, so no extra guard is required; consider adding an inline comment (`# will initialize a fresh DynamicCache below`) for clarity. Update the EAGLE-3 docs to specify that when `use_aux_hidden_state` is enabled, `base_model_outputs` must include an `aux_hidden_states` tensor shaped to match `eagle_aux_hidden_state_layer_ids`.
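For the docs, a minimal sketch of that contract (the tensor and config names follow the review text; the sizes are invented):

```python
import torch

# When use_aux_hidden_state is enabled, base_model_outputs must carry aux_hidden_states
# whose last dimension equals len(eagle_aux_hidden_state_layer_ids) * hidden_size.
hidden_size = 64
eagle_aux_hidden_state_layer_ids = [2, 5, 8]
aux_hidden_states = torch.randn(1, 16, len(eagle_aux_hidden_state_layer_ids) * hidden_size)

assert aux_hidden_states.shape[-1] == len(eagle_aux_hidden_state_layer_ids) * hidden_size, (
    "aux_hidden_states last dim must equal len(eagle_aux_hidden_state_layer_ids) * hidden_size"
)
```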
examples/speculative_decoding/README.md (10)
28-28
: Use consistent capitalization: “ModelOpt”.
Brand is capitalized elsewhere; align here.-Install Modelopt with `hf` dependencies and other requirements for this example: +Install ModelOpt with `hf` dependencies and other requirements for this example:
47-47
: Capitalize “ModelOpt”.-This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. +This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in ModelOpt.
83-90
: Explain d2t mapping with a concrete relation.
Add a short sentence showing how the mapping composes tokens.-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later. +This produces a `d2t.pt` file in `save_dir`, a mapping from draft vocab to full vocab used by the draft model. Conceptually: `tokens_full = tokens_draft + d2t[tokens_draft]`.
101-101
: Capitalize “ModelOpt”.-### Training Draft Model with Modelopt +### Training Draft Model with ModelOpt
106-110
: Import missing dependency in snippet.
Readers will copy/paste this; include the import.-```python -model = transformers.AutoModelForCausalLM.from_pretrained( +```python +import transformers +model = transformers.AutoModelForCausalLM.from_pretrained(
166-170
: Capitalize “ModelOpt”.-The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT. +The saved ModelOpt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.-After training draft model, we can evaluate the saved modelopt checkpoint on MT-bench by: +After training the draft model, evaluate the saved ModelOpt checkpoint on MT‑Bench with:
227-235
: Polish model names and add “last updated”.
Use canonical casing and note when the matrix was last validated.-| LLAMA 2 | ✅ | ✅ | ✅ | -| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | +| Llama 2 | ✅ | ✅ | ✅ | +| Llama 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | -| QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | +| Qwen 1.5, 2, 2.5 | ✅ | ✅ | ✅ |Optionally add a line above the table: “Last validated: 2025‑09‑05.”
93-99
: Clarify EAGLE-1 vs EAGLE-3 default configs and update link
- In the README example, call out which settings enable EAGLE-3 (use_aux_hidden_state, eagle_aux_hidden_state_layer_ids, use_mtp_layernorm, eagle_disable_moe, eagle_hidden_state_distillation, etc.) and update the URL to point directly at the EagleConfig default JSON in the ModelOpt API docs.
65-71
: vLLM flags and quantization option are up-to-date
--api-key
(orVLLM_API_KEY
) is the current auth flag.--quantization=modelopt
has been supported since v0.6.5.Consider showing environment-based auth (
VLLM_API_KEY
/BASE_URL
) in examples instead of embedding tokens.
194-215
: Update README example to match currenttrtllm-serve
flags and YAML schema
- Confirm the CLI flag
--extra_llm_api_options
is still supported in your installed version (it appears in the latest docs).- Under
speculative_config
, use the up-to-date keys:speculative_decoding_mode
(aliasdecoding_type
),max_draft_len
,speculative_model_dir
, etc.- Under
kv_cache_config
, include valid fields such asenable_block_reuse
,enable_partial_reuse
,copy_on_partial_reuse
,free_gpu_memory_fraction
, and any other keys required by your target release.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (7)
- examples/speculative_decoding/README.md (2 hunks)
- examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
- examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
- examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (4 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (3)
- examples/speculative_decoding/export_hf_checkpoint.py
- modelopt/torch/export/plugins/hf_spec_export.py
- examples/speculative_decoding/train_eagle3_and_export.sh
🧰 Additional context used
🪛 LanguageTool
examples/speculative_decoding/README.md
[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...
(QB_NEW_EN)
[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...
(QB_NEW_EN)
[grammar] ~71-~71: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...
(QB_NEW_EN)
[grammar] ~83-~83: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...
(QB_NEW_EN)
[grammar] ~103-~103: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load the base model and tokenizer...
(QB_NEW_EN)
[grammar] ~112-~112: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...
(QB_NEW_EN)
[grammar] ~112-~112: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...
(QB_NEW_EN)
[grammar] ~134-~134: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative decoding model: ...
(QB_NEW_EN)
[grammar] ~140-~140: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...
(QB_NEW_EN)
[grammar] ~170-~170: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...
(QB_NEW_EN)
[grammar] ~229-~229: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...
(QB_NEW_EN)
[grammar] ~230-~230: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...
(QB_NEW_EN)
[grammar] ~231-~231: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...
(QB_NEW_EN)
[grammar] ~232-~232: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...
(QB_NEW_EN)
[grammar] ~233-~233: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...
(QB_NEW_EN)
[grammar] ~234-~234: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...
(QB_NEW_EN)
examples/speculative_decoding/SLURM_prepare_data.md
[grammar] ~3-~3: There might be a mistake here.
Context: ...eneration we provide some SLURM support. Assuming a $SLURM_JOB_ID
is present an...
(QB_NEW_EN)
[grammar] ~24-~24: There might be a mistake here.
Context: ...generate.pyand
distributed_generate`. This will launch a vllm server (sglang i...
(QB_NEW_EN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
modelopt/torch/export/unified_export_hf.py (1)
30-33
: LGTM: clean plugin importImporting spec-decoding helpers here aligns the export flow with plugins.
modelopt/torch/speculative/plugins/transformers.py (2)
1082-1083
: LGTM: device alignment fixMoving
base_token
toinput_ids.device
avoids cross-device concat errors.
1160-1160
: LGTM: device-safe concatenation in AR validationEnsures both tensors share device during growth loop.
examples/speculative_decoding/README.md (3)
5-5
: Intro definition reads well; keep.
Clear and technically accurate overview of speculative decoding.
239-241
: Checkpoint links look good.
Clear, actionable pointers to deployable artifacts.
79-79
: SLURM guide file present
The referencedSLURM_prepare_data.md
exists inexamples/speculative_decoding/
—no link changes needed.
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
modelopt/torch/export/unified_export_hf.py (1)
520-548: Guard quantization config writes. In modelopt/torch/export/unified_export_hf.py (lines 520–548), wrap the saving of hf_quant_config.json and the insertion of "quantization_config" into config.json behind a check for applied quantization (e.g. `if hf_quant_config:`), so vanilla HF exports don't emit empty quant artifacts.
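A minimal sketch of that guard (names follow the review text; the surrounding export flow is simplified):

```python
import json
from pathlib import Path
from typing import Optional


def write_quant_artifacts(export_dir: str, hf_quant_config: Optional[dict], config_data: dict) -> dict:
    """Emit hf_quant_config.json and 'quantization_config' only when quantization was applied."""
    if hf_quant_config:  # falsy (None/empty) for vanilla HF exports, so nothing extra is written
        with open(Path(export_dir) / "hf_quant_config.json", "w") as f:
            json.dump(hf_quant_config, f, indent=4)
        config_data["quantization_config"] = hf_quant_config
    return config_data
```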
modelopt/torch/speculative/plugins/transformers.py (1)
1154-1160
: Avoid hard-coding CUDA device in validation.Use the input tensor’s device; current_device() breaks on CPU-only or different device placement.
- input_ids = copy.deepcopy(input_ids).to(torch.cuda.current_device()) + input_ids = copy.deepcopy(input_ids).to(input_ids.device) @@ - input_ids = torch.cat((input_ids, input_id.to(input_ids.device)), dim=-1) + input_ids = torch.cat((input_ids, input_id.to(input_ids.device)), dim=-1)
♻️ Duplicate comments (4)
modelopt/torch/export/unified_export_hf.py (2)
544-544
: Also guard spec-config transform to avoid AttributeError.Mirror the pruning guard for config transform.
(Implemented in the diff above.)
529-530
: Guard spec-decoding state-dict pruning (works on vanilla HF models).Call pruning only when ModelOpt state exists.
- post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict) + if hasattr(model, "_modelopt_state"): + post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)modelopt/torch/export/plugins/hf_spec_export.py (2)
87-95
: Guard _modelopt_state access to avoid crashes on vanilla HF.Directly indexing model._modelopt_state can raise AttributeError. Use the same guard pattern as in rename_and_prune_if_spec_decoding.
-def set_config_if_spec_decoding(model: nn.Module, config_data: dict): +def set_config_if_spec_decoding(model: nn.Module, config_data: dict): @@ - if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle": - # return as is - return config_data + opt_modes = getattr(model, "_modelopt_state", None) + if ( + not isinstance(opt_modes, (list, tuple)) + or len(opt_modes) != 1 + or opt_modes[0][0] != "eagle" + ): + return config_data
95-155
: Merge with existing config instead of overwriting it.Dropping unknown keys can remove fields like quantization_config; merge template into incoming config and deep-merge eagle_config.
- return template_config + # Merge: preserve existing fields while overriding with official template values. + merged = {**config_data, **template_config} + merged["eagle_config"] = { + **config_data.get("eagle_config", {}), + **template_config["eagle_config"], + } + return merged
🧹 Nitpick comments (10)
examples/speculative_decoding/SLURM_prepare_data.md (4)
3-5
: Tighten wording and fix grammar.Minor clarity nits.
-For basic parallelization of synthetic data generation we provide some SLURM support. -Assuming a `$SLURM_JOB_ID` is present and nodes, n1, n2, n3, n4 are selected the following is achievable. +For basic parallelization of synthetic data generation, we provide SLURM support. +Assuming `$SLURM_JOB_ID` is present and nodes n1, n2, n3, n4 are allocated, the following is achievable.
12-16
: Polish sentence casing.-Create shards of some given size +Create shards of a given size
24-26
: Capitalize framework names.-This will launch a vllm server (sglang is also available) on each node. Each node will work through 10 shards of data (10*max_lines_per_shard number of samples). +This launches a vLLM server (SGLang is also available) on each node. Each node will work through 10 shards of data (10*max_lines_per_shard samples).
29-31
: Keep argument list consistent between runs.The second launch example omits the system prompt argument; include it or add a note that it’s optional to avoid confusion.
-bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4 +bash distributed_generate/launch.sh $SLURM_JOB_ID vllm TinyLlama/TinyLlama-1.1B-Chat-v1.0 /data/train/ /data/output /scripts/ 40 10 n1,n2,n3,n4 "\"You are a helpful assistant.\""examples/speculative_decoding/README.md (5)
18-18
: Capitalize EAGLE and tighten phrasing.-| Simplified Workflow | Train, evaluate, and export eagle model with one-line command | \[[Link](#getting-started-simplified-workflow)\] | +| Simplified Workflow | Train, evaluate, and export the EAGLE model with a one‑line command | \[[Link](#getting-started-simplified-workflow)\] |
35-39
: Add Git LFS note for dataset clone.HF datasets via git require LFS; add a one‑liner to reduce user friction.
```bash +git lfs install git clone https://huggingface.co/datasets/nvidia/Daring-Anteater
--- `69-69`: **Clarify note reads better as a flag tip.** ```diff -Note: Add `--quantization=modelopt` flag for quantized models. +Tip: Add `--quantization=modelopt` when serving quantized models.
101-101
: Consistent branding: ModelOpt.-### (Optional) Configuring Draft Model +### (Optional) Configuring Draft Model (ModelOpt)
227-235
: Model names casing and minor formatting.-| LLAMA 2 | ✅ | ✅ | ✅ | -| LLAMA 3, 3.1 | ✅ | ✅ | ✅ | +| Llama 2 | ✅ | ✅ | ✅ | +| Llama 3, 3.1 | ✅ | ✅ | ✅ |modelopt/torch/speculative/plugins/transformers.py (1)
713-717
: Guard d2t usage with a clearer error.Assertion is fine, but a ValueError is more actionable when draft vocab remapping is expected.
- assert hasattr(self.eagle_module, "d2t"), "d2t buffer not initialized" + if not hasattr(self.eagle_module, "d2t"): + raise ValueError("Draft‑vocab remapping requested but eagle_module.d2t is not initialized")
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (7)
- examples/speculative_decoding/README.md (2 hunks)
- examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
- examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
- examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (4 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
🚧 Files skipped from review as they are similar to previous changes (2)
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/export_hf_checkpoint.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/README.md
🧬 Code graph analysis (3)
modelopt/torch/export/plugins/hf_spec_export.py (1)
- modelopt/torch/speculative/plugins/transformers.py (1)
  - HFEagleModel (333-1141)

modelopt/torch/export/unified_export_hf.py (2)
- modelopt/torch/export/plugins/hf_spec_export.py (2)
  - rename_and_prune_if_spec_decoding (54-84)
  - set_config_if_spec_decoding (87-155)
- modelopt/torch/export/convert_hf_config.py (1)
  - convert_hf_quant_config_format (21-117)

modelopt/torch/speculative/plugins/transformers.py (1)
- modelopt/torch/quantization/qtensor/base_qtensor.py (1)
  - to (115-123)
🪛 LanguageTool
examples/speculative_decoding/README.md
[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...
(QB_NEW_EN)
[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...
(QB_NEW_EN)
[grammar] ~71-~71: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...
(QB_NEW_EN)
[grammar] ~83-~83: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...
(QB_NEW_EN)
[grammar] ~103-~103: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load the base model and tokenizer...
(QB_NEW_EN)
[grammar] ~112-~112: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...
(QB_NEW_EN)
[grammar] ~112-~112: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...
(QB_NEW_EN)
[grammar] ~134-~134: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative decoding model: ...
(QB_NEW_EN)
[grammar] ~140-~140: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...
(QB_NEW_EN)
[grammar] ~170-~170: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...
(QB_NEW_EN)
[grammar] ~229-~229: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...
(QB_NEW_EN)
[grammar] ~230-~230: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...
(QB_NEW_EN)
[grammar] ~231-~231: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...
(QB_NEW_EN)
[grammar] ~232-~232: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...
(QB_NEW_EN)
[grammar] ~233-~233: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...
(QB_NEW_EN)
[grammar] ~234-~234: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...
(QB_NEW_EN)
examples/speculative_decoding/SLURM_prepare_data.md
[grammar] ~3-~3: There might be a mistake here.
Context: ...eneration we provide some SLURM support. Assuming a $SLURM_JOB_ID
is present an...
(QB_NEW_EN)
[grammar] ~24-~24: There might be a mistake here.
Context: ...generate.pyand
distributed_generate`. This will launch a vllm server (sglang i...
(QB_NEW_EN)
🔇 Additional comments (3)
modelopt/torch/export/plugins/hf_spec_export.py (1)

`54-66`: **Tolerate models without _modelopt_state.**

rename_and_prune_if_spec_decoding is already guarded; good.

modelopt/torch/speculative/plugins/transformers.py (2)

`680-692`: **Forward signature change looks good.**

Accepts **kwargs for extensibility while retaining existing args.

`1022-1030`: **No downstream consumers of `train_acc` found. Merging is safe.**
Actionable comments posted: 3
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (1)
examples/speculative_decoding/README.md (1)
`145-153`: **Don't rely on private Trainer APIs; fix undefined `checkpoint`.**

Trainer will move the model to device during training. Also, `checkpoint` is undefined.

```diff
-trainer._move_model_to_device(model, trainer.args.device)
@@
-# Enable HF checkpointing so that the saved model will contain the speculative decoding module
+# Enable HF checkpointing so that the saved model will contain the speculative decoding module
 mto.enable_huggingface_checkpointing()
-trainer.train(resume_from_checkpoint=checkpoint)
+trainer.train()
 trainer.save_state()
 trainer.save_model("<path to the output directory>")
```
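Applied to the README example, the resulting flow would look roughly like this. This is a sketch: `model`, `training_args`, and `train_dataset` are placeholders carried over from the surrounding example, and `mto` is assumed to be the ModelOpt opt module imported as in the README.

```python
import modelopt.torch.opt as mto
from transformers import Trainer

trainer = Trainer(model=model, args=training_args, train_dataset=train_dataset)

# Enable HF checkpointing so the saved model keeps the speculative decoding module.
mto.enable_huggingface_checkpointing()

trainer.train()  # Trainer moves the model to the right device on its own.
trainer.save_state()
trainer.save_model("<path to the output directory>")
```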
♻️ Duplicate comments (2)
examples/speculative_decoding/README.md (2)
`15-22`: **TOC anchors now resolve correctly.**

Broken anchors/typo from earlier review are fixed.

`172-185`: **Code fences now correctly use bash and grammar is fixed.**

Prior issues are resolved.
🧹 Nitpick comments (4)
examples/speculative_decoding/README.md (4)
`49-53`: **Avoid line-number deep links that can drift.**

Linking to `...default_config.py#L18` is brittle. Point to the file (or a permalink) instead.

```diff
-- Initializes the draft model with [default settings](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py#L18)
+- Initializes the draft model with [default settings](https://github.com/NVIDIA/TensorRT-Model-Optimizer/blob/main/modelopt/torch/speculative/eagle/default_config.py)
```

`83-90`: **Great fix on the flag; tighten wording and clarify mapping.**

The extra dashes issue is resolved. Suggest minor wording plus a one-liner showing how d2t is applied.

```diff
-We can optionally use smaller vocab size for the draft model for faster training and inference. E.g. Llama3.2-1B has a vocab size of 128256. In this example, we construct a draft vocab mapping of size 32k by finding the most commonly appeared vocabs in our training set:
+We can optionally use a smaller draft vocabulary to speed up training/inference. For example, Llama-3.2-1B uses a 128,256-token vocabulary. Here we build a 32k draft-to-target mapping by selecting the most frequent tokens in the training set:
@@
-This will produce a `d2t.pt` file in `save_dir`, which is the mapping from draft vocabs to full vocab that will be read by our draft model later.
+This produces `d2t.pt` in `save_dir`, a draft-to-target mapping the draft model will load later (applied roughly as: accepted_tokens = draft_tokens + d2t[draft_tokens]).
```
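For context, building such a draft-to-target mapping amounts to ranking full-vocab token ids by frequency in the training set and recording, for each kept draft id, its offset into the full vocabulary. A rough sketch follows; the helper is hypothetical and is not the example's calibrate_draft_vocab.py implementation.

```python
from collections import Counter

import torch


def build_d2t(token_id_stream, draft_vocab_size: int) -> torch.Tensor:
    """Return d2t offsets so that full_id = draft_id + d2t[draft_id]."""
    counts = Counter(token_id_stream)  # frequency of each full-vocab id in the training set
    top_ids = [tok for tok, _ in counts.most_common(draft_vocab_size)]
    top_ids.sort()  # draft id i maps to the i-th kept full-vocab id
    d2t = torch.tensor(top_ids, dtype=torch.long) - torch.arange(len(top_ids))
    return d2t  # persist with torch.save(d2t, "d2t.pt")
```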
`93-99`: **Link stability.**

`config.py#L37` may drift; consider linking to the file (or a permalink).

`166-167`: **Unify brand capitalization: ModelOpt.**

Use "ModelOpt" consistently.

```diff
-The saved modelopt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
+The saved ModelOpt checkpoint is similar in architecture to HF models. It can be further optimized through **ModelOpt**, e.g., PTQ and QAT.
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
💡 Knowledge Base configuration:
- MCP integration is disabled by default for public repositories
- Jira integration is disabled by default for public repositories
- Linear integration is disabled by default for public repositories
You can enable these sources in your CodeRabbit configuration.
📒 Files selected for processing (2)
- examples/speculative_decoding/README.md (2 hunks)
- modelopt/torch/export/unified_export_hf.py (4 hunks)
🚧 Files skipped from review as they are similar to previous changes (1)
- modelopt/torch/export/unified_export_hf.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
examples/speculative_decoding/README.md
🪛 LanguageTool
examples/speculative_decoding/README.md
[grammar] ~21-~21: There might be a mistake here.
Context: ...nk](#speculation-module-checkpoints)] | | Resources | Extra links to relevant re...
(QB_NEW_EN)
[grammar] ~50-~50: There might be a mistake here.
Context: ...datasets/nvidia/Daring-Anteater) dataset - Evaluates the acceptance rate on [MT-Ben...
(QB_NEW_EN)
[grammar] ~71-~71: There might be a mistake here.
Context: ...odels. Then, we generate conversations with base model and prompts from Daring-Ante...
(QB_NEW_EN)
[grammar] ~83-~83: There might be a mistake here.
Context: ...cabulary Compression We can optionally use smaller vocab size for the draft model ...
(QB_NEW_EN)
[grammar] ~103-~103: There might be a mistake here.
Context: ...g it. It consists of a few simple steps: First, load the base model and tokenizer...
(QB_NEW_EN)
[grammar] ~112-~112: There might be a mistake here.
Context: ...to your pretrained model>" ) ``` Then, load default eagle config and make necessary...
(QB_NEW_EN)
[grammar] ~112-~112: There might be a mistake here.
Context: ...`` Then, load default eagle config and make necessary overwrites: ```python # Load...
(QB_NEW_EN)
[grammar] ~134-~134: There might be a mistake here.
Context: ...ition_embeddings, } ) ``` Then, we convert model to a speculative decoding model: ...
(QB_NEW_EN)
[grammar] ~140-~140: There might be a mistake here.
Context: ... training forward, making it compatible with HF trainer: ```python # Create a train...
(QB_NEW_EN)
[grammar] ~170-~170: There might be a mistake here.
Context: ...Q and QAT. ### Model Validation After training draft model, we can evaluate the saved ...
(QB_NEW_EN)
[grammar] ~229-~229: There might be a mistake here.
Context: ... | Model | Medusa | EAGLE1/2 | EAGLE3 | | :---: | :---: | :---: | :---: | | LLAM...
(QB_NEW_EN)
[grammar] ~230-~230: There might be a mistake here.
Context: ...GLE3 | | :---: | :---: | :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 |...
(QB_NEW_EN)
[grammar] ~231-~231: There might be a mistake here.
Context: ... :---: | :---: | | LLAMA 2 | ✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral |...
(QB_NEW_EN)
[grammar] ~232-~232: There might be a mistake here.
Context: ...✅ | ✅ | ✅ | | LLAMA 3, 3.1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ ...
(QB_NEW_EN)
[grammar] ~233-~233: There might be a mistake here.
Context: ....1 | ✅ | ✅ | ✅ | | Mistral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 |...
(QB_NEW_EN)
[grammar] ~234-~234: There might be a mistake here.
Context: ...tral | ✅ | ✅ | ✅ | | Phi 3 | ✅ | ✅ | ✅ | | QWen 1.5,2,2.5 | ✅ | ✅ | ✅ | ## Specu...
(QB_NEW_EN)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (6)
examples/speculative_decoding/README.md (6)
`5-6`: **Clear, technically accurate intro.**

The γ/α explanation is concise and correct.

`227-236`: **Support matrix may be outdated.**

A reviewer noted more models are now supported; update once verified on this HF path.

`239-241`: **Nice addition.**

Checkpoint collection link and deployment note are helpful.

`71-75`: **No action needed: script path is correct.**

server_generate.py is present at examples/speculative_decoding/server_generate.py, matching the README's invocation when run from that directory.

`194-205`: **Validate TRT-LLM serve flags and adjust wording.**

Most CLI flags (`--backend pytorch`, `--max_batch_size`, `--max_num_tokens`, `--max_seq_len`, `--extra_llm_api_options`) and YAML keys (`enable_attention_dp`, `disable_overlap_scheduler`, `cuda_graph_config`, `speculative_config.*`) match current TRT-LLM docs. I couldn't locate `enable_autotuner`; please confirm it's supported in your target release or remove it.

```diff
-, with `extra-llm-api-config.yml` being
+With `extra-llm-api-config.yml`:
```

`62-70`: **vLLM flags confirmed.**

`--api-key` is optional for local runs; if provided (or via `VLLM_API_KEY`), the server will enforce the specified key(s) in request headers. `--quantization=modelopt` is supported as one of vLLM's quantization backends.
Force-pushed from ed1f5a3 to 240eac1 (compare)
This needs to be addressed #293 (comment). Otherwise, LGTM.
Force-pushed from 023b27a to faf36b7 (compare)
Actionable comments posted: 2
♻️ Duplicate comments (11)
examples/speculative_decoding/launch_train.sh (1)
`87-91`: **Guard torch detection and prevent division-by-zero; prefer NUM_GPU when available.**

`python -c 'import torch'` can fail under `set -e`, and `torch.cuda.device_count()` may return 0, yielding a division-by-zero or `DEFAULT_SAVE_STEPS=0`. Use `${NUM_GPU}` when valid, fall back robustly to 1, and clamp to min 1.

```diff
-# Get the default value for save_steps based on the available number of GPUs
-GPU_COUNT=$(python -c "import torch; print(torch.cuda.device_count())")
-# Calculate save_steps
-DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+# Get the default value for save_steps based on the available number of GPUs
+# Prefer NUM_GPU if provided; else try torch; finally fall back to 1.
+if [[ -n "${NUM_GPU:-}" && "${NUM_GPU}" =~ ^[0-9]+$ && ${NUM_GPU} -gt 0 ]]; then
+  GPU_COUNT=${NUM_GPU}
+else
+  GPU_COUNT=$(python - <<'PY'
+try:
+    import torch
+    print(torch.cuda.device_count() or 0)
+except Exception:
+    print(0)
+PY
+  ) || GPU_COUNT=0
+  [[ -z "$GPU_COUNT" || "$GPU_COUNT" -le 0 ]] && GPU_COUNT=1
+fi
+# Calculate save_steps safely (min 1)
+DEFAULT_SAVE_STEPS=$((8192 / GPU_COUNT))
+(( DEFAULT_SAVE_STEPS < 1 )) && DEFAULT_SAVE_STEPS=1
```
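If this logic ever migrates into the Python entry point instead of the shell wrapper, the same fallback is only a few lines. A sketch under that assumption (the function name is illustrative):

```python
def detect_gpu_count(num_gpu_arg=None) -> int:
    """Best-effort GPU count: trust a valid --num_gpu value, else ask torch, else 1."""
    if num_gpu_arg and str(num_gpu_arg).isdigit() and int(num_gpu_arg) > 0:
        return int(num_gpu_arg)
    try:
        import torch

        return max(torch.cuda.device_count(), 1)
    except Exception:
        return 1


# save_steps scales inversely with GPU count and never drops below 1
default_save_steps = max(8192 // detect_gpu_count(), 1)
```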
modelopt/torch/export/plugins/hf_spec_export.py (6)

`18-20`: **Record transformers version (best-effort) for traceability.**

Populate transformers_version to aid downstream tooling.

```diff
-import torch
+import torch
 import torch.nn as nn
+try:
+    import transformers as _hf
+    _TRANSFORMERS_VERSION = getattr(_hf, "__version__", None)
+except Exception:
+    _TRANSFORMERS_VERSION = None
```

`76-79`: **Harden lm_head fallback; raise if neither drafter nor base provides it.**

Prevents silent KeyError from model.state_dict()["lm_head.weight"].

```diff
-    if "eagle_lm_head.weight" not in eagle_state:
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    if "eagle_lm_head.weight" not in eagle_state:
+        base_state = model.state_dict()
+        if "lm_head.weight" in base_state:
+            export_state_dict["lm_head.weight"] = base_state["lm_head.weight"]
+        else:
+            raise KeyError("Missing 'eagle_lm_head.weight' in drafter and 'lm_head.weight' in base model.")
```

`127-134`: **Guard eagle_config/config access before getattr.**

Avoids AttributeError when either container is missing.

```diff
-    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
-        if getattr(model.eagle_config, key, None) is not None:
-            return getattr(model.eagle_config, key)
-        elif getattr(model.config, key, None) is not None:
-            return getattr(model.config, key)
-        else:
-            return None
+    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
+        eagle_cfg = getattr(model, "eagle_config", None)
+        if eagle_cfg is not None and getattr(eagle_cfg, key, None) is not None:
+            return getattr(eagle_cfg, key)
+        base_cfg = getattr(model, "config", None)
+        if base_cfg is not None and getattr(base_cfg, key, None) is not None:
+            return getattr(base_cfg, key)
+        return None
```

`89-108`: **Preserve original config fields and merge; set transformers_version.**

Avoid dropping unknown config keys and persist the detected Transformers version.

```diff
-    "transformers_version": None,
+    "transformers_version": _TRANSFORMERS_VERSION,
@@
-    return template_config
+    # Merge with original to preserve unknown keys.
+    merged = {**config_data, **template_config}
+    merged["eagle_config"] = {
+        **config_data.get("eagle_config", {}),
+        **template_config["eagle_config"],
+    }
+    return merged
```

Also applies to: 149-149

`84-88`: **set_config_if_spec_decoding crashes on vanilla HF models.**

Direct len(model._modelopt_state) will raise when the attribute is absent.

```diff
-    if len(model._modelopt_state) != 1 or model._modelopt_state[0][0] != "eagle":
+    opt_modes = getattr(model, "_modelopt_state", None)
+    if (
+        not isinstance(opt_modes, (list, tuple))
+        or len(opt_modes) != 1
+        or opt_modes[0][0] != "eagle"
+    ):
         # return as is
         return config_data
```

`51-65`: **Guard for missing eagle_module to avoid AttributeError.**

Accessing model.eagle_module without a presence check can crash on non-Eagle models.

```diff
-    # Check if the state dict keys match
-    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+    # Ensure eagle_module exists
+    if not hasattr(model, "eagle_module"):
+        return post_state_dict
+    # Check if the state dict keys match
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
```

modelopt/torch/export/unified_export_hf.py (4)
`494-498`: **Fix _quant_applied logic; make it reliable.**

Current check can return True for empty configs and is used nowhere. Tighten and use it.

```diff
-def _quant_applied(hf_quant_config: dict) -> bool:
-    """Check if any quantization is applied."""
-    q = hf_quant_config.get("quantization", {})
-    return not (q.get("quant_algo") == QUANTIZATION_NONE and not q.get("quantized_layers"))
+def _quant_applied(hf_quant_config: dict) -> bool:
+    """Return True iff any quantization is configured."""
+    q = hf_quant_config.get("quantization") or {}
+    algo = q.get("quant_algo")
+    layers = q.get("quantized_layers")
+    return (algo is not None and algo != QUANTIZATION_NONE) or bool(layers)
```
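To make the intended truth table concrete, here is a standalone sketch of the tightened predicate. `QUANTIZATION_NONE` is stubbed with a placeholder value; the real constant lives in ModelOpt's export code.

```python
QUANTIZATION_NONE = "NONE"  # placeholder; the real constant is defined in the export module


def quant_applied(hf_quant_config: dict) -> bool:
    """True iff the config names a real quant algorithm or lists quantized layers."""
    q = hf_quant_config.get("quantization") or {}
    algo = q.get("quant_algo")
    layers = q.get("quantized_layers")
    return (algo is not None and algo != QUANTIZATION_NONE) or bool(layers)


assert not quant_applied({})                                        # empty config: no quant
assert not quant_applied({"quantization": {}})                      # empty section: no quant
assert not quant_applied({"quantization": {"quant_algo": "NONE"}})  # explicit no-op algo
assert quant_applied({"quantization": {"quant_algo": "FP8"}})
assert quant_applied({"quantization": {"quantized_layers": {"lm_head": "FP8"}}})
```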
`519-523`: **Don't emit empty hf_quant_config.json.**

Only write when quantization is actually applied; also use Path for portability.

```diff
-    # Save hf_quant_config.json for backward compatibility
-    with open(f"{export_dir}/hf_quant_config.json", "w") as file:
-        json.dump(hf_quant_config, file, indent=4)
+    # Save hf_quant_config.json only when quantization is applied
+    if _quant_applied(hf_quant_config):
+        with open(Path(export_dir) / "hf_quant_config.json", "w") as file:
+            json.dump(hf_quant_config, file, indent=4)
```

`526-527`: **Guard spec-decoding post-state transform when ModelOpt state is absent.**

Prevents AttributeError on vanilla HF models.

```diff
-    post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
+    if hasattr(model, "_modelopt_state"):
+        post_state_dict = rename_and_prune_if_spec_decoding(model, post_state_dict)
```

`539-544`: **Preserve quantization_config and avoid spec-config clobbering; gate on quant applied.**

Move the assignment after the spec-config transform and gate it. Also remove the stale key when no quantization is applied.

```diff
-    config_data["quantization_config"] = hf_quant_config
-
-    config_data = set_config_if_spec_decoding(model, config_data)
+    if hasattr(model, "_modelopt_state"):
+        config_data = set_config_if_spec_decoding(model, config_data)
+    if _quant_applied(hf_quant_config):
+        config_data["quantization_config"] = hf_quant_config
+    else:
+        config_data.pop("quantization_config", None)
```
🧹 Nitpick comments (4)
examples/speculative_decoding/launch_train.sh (4)
`111-116`: **Quote config path to handle spaces and special chars.**

Unquoted paths can break argument parsing.

```diff
-    SPECULATIVE_ARGS="--eagle_config $EAGLE_CONFIG"
+    SPECULATIVE_ARGS="--eagle_config \"$EAGLE_CONFIG\""
```

`130-153`: **Quote key CLI values to be robust to spaces.**

Quote MODEL, OUTPUT_DIR, and DATA in the assembled command.

```diff
-    --model_name_or_path $MODEL \
+    --model_name_or_path \"$MODEL\" \
@@
-    --output_dir $OUTPUT_DIR \
+    --output_dir \"$OUTPUT_DIR\" \
@@
-    --data_path $DATA \
+    --data_path \"$DATA\" \
```

`103-106`: **Remove or wire unused variables.**

`REDRAFTER_TOKENS` and `REDRAFTER_NUM_LAYERS` are set but unused.

`77-80`: **Improve invalid-argument error message.**

Print the full offending token; current expansion may mangle it.

```diff
-    >&2 printf "Error: Invalid argument ${1#*=}\n"
+    >&2 printf "Error: Invalid argument: %s\n" "$1"
```
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
- examples/speculative_decoding/README.md (2 hunks)
- examples/speculative_decoding/SLURM_prepare_data.md (1 hunks)
- examples/speculative_decoding/ar_validate.py (2 hunks)
- examples/speculative_decoding/calibrate_draft_vocab.py (2 hunks)
- examples/speculative_decoding/eagle_config.json (1 hunks)
- examples/speculative_decoding/export_hf_checkpoint.py (1 hunks)
- examples/speculative_decoding/launch_train.sh (1 hunks)
- examples/speculative_decoding/main.py (3 hunks)
- examples/speculative_decoding/server_generate.py (1 hunks)
- examples/speculative_decoding/train_eagle3_and_export.sh (1 hunks)
- modelopt/torch/export/plugins/__init__.py (1 hunks)
- modelopt/torch/export/plugins/hf_spec_export.py (1 hunks)
- modelopt/torch/export/unified_export_hf.py (4 hunks)
- modelopt/torch/speculative/plugins/transformers.py (11 hunks)
✅ Files skipped from review due to trivial changes (1)
- examples/speculative_decoding/SLURM_prepare_data.md
🚧 Files skipped from review as they are similar to previous changes (10)
- examples/speculative_decoding/ar_validate.py
- examples/speculative_decoding/calibrate_draft_vocab.py
- examples/speculative_decoding/eagle_config.json
- modelopt/torch/export/plugins/init.py
- examples/speculative_decoding/server_generate.py
- examples/speculative_decoding/main.py
- examples/speculative_decoding/README.md
- examples/speculative_decoding/export_hf_checkpoint.py
- examples/speculative_decoding/train_eagle3_and_export.sh
- modelopt/torch/speculative/plugins/transformers.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/export/plugins/hf_spec_export.py
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/export/plugins/hf_spec_export.py
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/export/plugins/hf_spec_export.py
🧬 Code graph analysis (1)
modelopt/torch/export/unified_export_hf.py (1)
modelopt/torch/export/plugins/hf_spec_export.py (2)
rename_and_prune_if_spec_decoding
(51-80)set_config_if_spec_decoding
(83-149)
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (3)
- GitHub Check: linux
- GitHub Check: code-quality
- GitHub Check: build-docs
🔇 Additional comments (3)
examples/speculative_decoding/launch_train.sh (1)
`151-152`: **Validate required `--data` early.**

Insert immediately after `export TOKENIZERS_PARALLELISM=False` in examples/speculative_decoding/launch_train.sh:

```bash
if [[ -z "${DATA:-}" ]]; then
  echo "Error: --data PATH is required." >&2
  exit 1
fi
```

modelopt/torch/export/plugins/hf_spec_export.py (1)

`54-61`: **Nice: tolerant gate on _modelopt_state in rename path.**

Good defensive checks; keeps non-Eagle paths unaffected.

modelopt/torch/export/unified_export_hf.py (1)

`56-56`: **Import placement is good; plugin usage clear.**

Spec-decoding integration point is well-scoped.
```bash
  --fsdp_transformer_layer_cls_to_wrap*)
    if [[ "$1" != *=* ]]; then shift; fi
    FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP="${1#*=}"
    ;;
  --num_gpu*)
    if [[ "$1" != *=* ]]; then shift; fi
    NUM_GPU="${1#*=}"
    ;;
```
**Parsed FSDP arg is never forwarded to the training command.**

`--fsdp_transformer_layer_cls_to_wrap` is parsed but not passed to `main.py`, so user input is ignored.

```diff
@@
   --num_gpu*)
     if [[ "$1" != *=* ]]; then shift; fi
     NUM_GPU="${1#*=}"
     ;;
@@
 fi
@@
 export TOKENIZERS_PARALLELISM=False
+FSDP_ARGS=""
+if [[ -n "${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP:-}" ]]; then
+  FSDP_ARGS="--fsdp_transformer_layer_cls_to_wrap ${FSDP_TRANSFORMER_LAYER_CLS_TO_WRAP}"
+fi
 CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
@@
   --data_path $DATA \
-  $SPECULATIVE_ARGS
+  $SPECULATIVE_ARGS \
+  $FSDP_ARGS
 "
```
Also applies to: 130-153
if [[ "$NUM_GPU" == 1 ]]; then | ||
MULTI_GPU="" | ||
else | ||
MULTI_GPU="--multi_gpu" | ||
fi | ||
|
💡 Verification agent
❓ Verification inconclusive
**Multi-GPU launch: pass explicit process count to Accelerate; `--multi_gpu` may be ignored on newer versions.**

Safer to specify `--num_processes "${NUM_GPU}"` and drop the custom flag.

```diff
-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
+LAUNCH_OPTS="--mixed_precision bf16"
+if [[ "${NUM_GPU}" -gt 1 ]]; then
+  LAUNCH_OPTS+=" --num_processes ${NUM_GPU}"
+fi
@@
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch $LAUNCH_OPTS main.py \
```
Also applies to: 130-131
**Use explicit `--num_processes` instead of `--multi_gpu` for multi-GPU runs.**

Accelerate launch supports `--num_processes=<N>` alone to spawn N processes (and implicitly enable multi-GPU) without requiring `--multi_gpu` (huggingface.co, modeldatabase.com).

```diff
-if [[ "$NUM_GPU" == 1 ]]; then
-  MULTI_GPU=""
-else
-  MULTI_GPU="--multi_gpu"
-fi
+LAUNCH_OPTS="--mixed_precision bf16"
+if [[ "${NUM_GPU}" -gt 1 ]]; then
+  LAUNCH_OPTS+=" --num_processes ${NUM_GPU}"
+fi
@@
-CMD="accelerate launch $MULTI_GPU --mixed_precision bf16 main.py \
+CMD="accelerate launch $LAUNCH_OPTS main.py \
```
Also update the same pattern at lines 130–131.
Committable suggestion skipped: line range outside the PR's diff.
🤖 Prompt for AI Agents
In examples/speculative_decoding/launch_train.sh around lines 122–127 (and also
update the same pattern at lines 130–131), replace the current multi-GPU flag
logic that sets MULTI_GPU="--multi_gpu" with an explicit process-count flag:
when NUM_GPU==1 keep MULTI_GPU empty, otherwise set
MULTI_GPU="--num_processes=$NUM_GPU"; update any subsequent invocations that
previously relied on --multi_gpu to use this MULTI_GPU variable so Accelerate is
launched with --num_processes=<N> instead of --multi_gpu.
/ok to test faf36b7
/ok to test 6b2411e
```python
import torch.nn as nn

EAGLE_MODELOPT_TO_OFFICIAL = {
    "required": {
```
is this from megatron to HF?
This is the param name mapping from modelopt::HFEagleModel to the eagle3 official checkpoint. I believe both Megatron and HF export to the same format.
The latter is also consistent with our previously released eagle checkpoints.
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Signed-off-by: h-guo18 <[email protected]>
Force-pushed from 6b2411e to 854564e (compare)
/ok to test 854564e
Actionable comments posted: 1
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
examples/speculative_decoding/server_generate.py (2)
`138-154`: **Use the Completions API in non-chat mode (current call is inconsistent and likely broken).**

You're passing `prompt=` to `client.chat.completions.create` and later reading `.text`. Switch to `client.completions.create` and use the returned `.text`.

```diff
-        response = client.chat.completions.create(
+        response = client.completions.create(
             model=model_name,
             prompt=prompt,
             max_tokens=args.max_tokens,
             temperature=args.temperature,
             ignore_eos=False,
             skip_special_tokens=False,
             spaces_between_special_tokens=False,
         )
-        response = response.choices[0].text.strip()
+        response_text = response.choices[0].text.strip()
         with open(args.output_path, "a") as f:
             # write in share gpt format
             if args.log_empty_conversations:
-                to_write = {"conversation_id": idx, "text": prompt + response}
+                to_write = {"conversation_id": idx, "text": prompt + response_text}
             else:
-                to_write = {"text": prompt + response}
+                to_write = {"text": prompt + response_text}
             f.write(json.dumps(to_write) + "\n")
```
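For reference, the two call shapes differ roughly as below. This is a sketch against the OpenAI-compatible Python client; the base URL, API key, and model name are illustrative, and the vLLM-specific sampling extras from the script are omitted.

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="dummy")  # local OpenAI-compatible server

# Chat mode: messages in, .message.content out
chat = client.chat.completions.create(
    model="my-model",
    messages=[{"role": "user", "content": "Hello"}],
    max_tokens=64,
)
chat_text = chat.choices[0].message.content

# Non-chat (completions) mode: prompt in, .text out
comp = client.completions.create(
    model="my-model",
    prompt="Hello",
    max_tokens=64,
)
comp_text = comp.choices[0].text
```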
`155-159`: **Avoid UnboundLocalError when printing `prompt` in exceptions.**

`prompt` is undefined in chat mode or when the error occurs before assignment.

```diff
-    except Exception as e:
-        print(e)
-        print(prompt)
-        print("Failed to generate data")
+    except Exception as e:
+        print(e)
+        if "prompt" in locals():
+            print(prompt)
+        print("Failed to generate data")
```
♻️ Duplicate comments (6)
examples/speculative_decoding/server_generate.py (1)
`49-49`: **Fix boolean CLI parsing for --chat.**

`type=bool` treats any non-empty string (even "False") as True. Use BooleanOptionalAction so `--chat`/`--no-chat` work correctly.

```diff
-parser.add_argument("--chat", default=True, type=bool, help="Use chat mode")
+parser.add_argument(
+    "--chat",
+    action=argparse.BooleanOptionalAction,
+    default=True,
+    help="Use chat mode (use --no-chat to disable)",
+)
```
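As a quick illustration of why `type=bool` misbehaves and what `BooleanOptionalAction` gives you, a standalone sketch (independent of server_generate.py):

```python
import argparse

# Broken: bool("False") is True, so "--chat False" still enables chat mode.
p1 = argparse.ArgumentParser()
p1.add_argument("--chat", default=True, type=bool)
print(p1.parse_args(["--chat", "False"]).chat)  # True (surprising)

# Fixed: paired --chat / --no-chat flags with a sane default (Python 3.9+).
p2 = argparse.ArgumentParser()
p2.add_argument("--chat", action=argparse.BooleanOptionalAction, default=True)
print(p2.parse_args([]).chat)             # True
print(p2.parse_args(["--no-chat"]).chat)  # False
```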
modelopt/torch/export/plugins/hf_spec_export.py (4)

`89-125`: **Preserve original `config_data` fields and set `transformers_version`; deep-merge `eagle_config`.**

Current code replaces the entire config, dropping unknown fields. Merge defaults into the original and record the Transformers version.

```diff
-    # This is the config keys in official checkpoint.
+    # This is the config keys in official checkpoint.
     template_config = {
@@
-        "transformers_version": None,
+        "transformers_version": getattr(transformers, "__version__", None),
@@
-    for key in template_config:
+    for key in template_config:
         value = template_config[key]
@@
-            template_config[key] = new_value
+            template_config[key] = new_value
@@
-    return template_config
+    # Merge: keep any unknown keys from original config_data
+    merged = {**config_data, **template_config}
+    # Deep-merge nested eagle_config
+    merged["eagle_config"] = {
+        **config_data.get("eagle_config", {}),
+        **template_config["eagle_config"],
+    }
+    return merged
```

Also applies to: 135-149

`76-79`: **Harden fallback for `lm_head.weight` export.**

Explicitly check base model state and raise a clear error if neither key exists.

```diff
-    # TODO: (hg) this is a temp fix. Find cleaner way to do this.
-    if "eagle_lm_head.weight" not in eagle_state:
-        export_state_dict["lm_head.weight"] = model.state_dict()["lm_head.weight"]
+    # TODO: (hg) this is a temp fix. Find cleaner way to do this.
+    if "eagle_lm_head.weight" not in eagle_state:
+        base_state = model.state_dict()
+        if "lm_head.weight" in base_state:
+            export_state_dict["lm_head.weight"] = base_state["lm_head.weight"]
+        else:
+            raise KeyError(
+                "Missing 'eagle_lm_head.weight' in draft and 'lm_head.weight' in base model."
+            )
```

`127-134`: **Guard access to `eagle_config`/`config` to avoid AttributeError.**

```diff
-    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
-        if getattr(model.eagle_config, key, None) is not None:
-            return getattr(model.eagle_config, key)
-        elif getattr(model.config, key, None) is not None:
-            return getattr(model.config, key)
-        else:
-            return None
+    def _get_config_from_eagle_config_or_base_config(key: str, model: nn.Module):
+        eagle_cfg = getattr(model, "eagle_config", None)
+        if eagle_cfg is not None and getattr(eagle_cfg, key, None) is not None:
+            return getattr(eagle_cfg, key)
+        base_cfg = getattr(model, "config", None)
+        if base_cfg is not None and getattr(base_cfg, key, None) is not None:
+            return getattr(base_cfg, key)
+        return None
```

`51-66`: **Guard for Eagle mode and presence of `eagle_module` before accessing it.**

If `_modelopt_state` indicates Eagle but `eagle_module` is missing, this raises `AttributeError`. Early-return safely.

```diff
 def rename_and_prune_if_spec_decoding(model: nn.Module, post_state_dict: dict):
@@
-    opt_modes = getattr(model, "_modelopt_state", None)
+    opt_modes = getattr(model, "_modelopt_state", None)
     if (
         not isinstance(opt_modes, (list, tuple))
         or len(opt_modes) != 1
         or opt_modes[0][0] != "eagle"
     ):
         # if there's other opts, return as is
         return post_state_dict
+    if not hasattr(model, "eagle_module"):
+        # Not an Eagle-wrapped model; return unchanged
+        return post_state_dict
@@
-    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
+    _check_state_dict_keys_match(model.eagle_module, EAGLE_MODELOPT_TO_OFFICIAL["required"])
```

examples/speculative_decoding/README.md (1)
`28-33`: **Provide concrete install commands (replace placeholder).**

Make this copy-pastable; include PyPI install and editable source install with HF extras.

````diff
-Install Modelopt with `hf` dependencies and other requirements for this example:
+Install ModelOpt with `hf` extras and example requirements:
@@
-```bash
-pip install -e ...
-pip install -r requirements.txt
-```
+```bash
+# Option A: Install published package (recommended)
+pip install "nvidia-modelopt[hf]"
+
+# Option B: Install from source
+pip install -e .[hf]
+
+# Example-specific deps
+pip install -r requirements.txt
+```
````
🧹 Nitpick comments (6)
examples/speculative_decoding/server_generate.py (3)
`56-57`: **Use list default for --system_prompt to avoid odd joins.**

Defaulting to "" makes `" ".join(args.system_prompt)` iterate characters. Use an empty list.

```diff
-parser.add_argument("--system_prompt", nargs="+", type=str, default="", help="System prompt")
+parser.add_argument("--system_prompt", nargs="+", type=str, default=[], help="System prompt")
```

`43-44`: **Lower default thread count to a safer value.**

256 workers can easily overwhelm local servers and hit rate limits. Suggest 32–64 by default.

```diff
-parser.add_argument(
-    "--num_threads", type=int, default=256, help="Number of threads to use (batch size)"
+parser.add_argument(
+    "--num_threads", type=int, default=64, help="Number of threads to use (batch size)"
 )
```

`60-65`: **Use context manager for JSON load for consistency and file handle safety.**

```diff
-if args.data_path.endswith("jsonl"):
-    with open(args.data_path) as f:
-        data = [json.loads(line) for line in f]
-else:
-    data = json.load(open(args.data_path))
+if args.data_path.endswith("jsonl"):
+    with open(args.data_path) as f:
+        data = [json.loads(line) for line in f]
+else:
+    with open(args.data_path) as f:
+        data = json.load(f)
```

modelopt/torch/export/plugins/hf_spec_export.py (1)

`18-20`: **Record Transformers version for provenance.**

Import transformers to populate transformers_version in the exported config.

```diff
 import torch
 import torch.nn as nn
+import transformers
```

examples/speculative_decoding/README.md (2)

`47-48`: **Brand consistency: "ModelOpt" capitalization.**

Use "ModelOpt" consistently.

```diff
-This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in Modelopt. Specifically, it
+This one-line command runs a minimal example workflow of training and exporting an EAGLE draft model in ModelOpt. Specifically, it
```

`229-236`: **Validate or trim the Support Matrix to avoid drift.**

Section can get stale quickly; either confirm the list against current HF path support or drop the table and link to dynamic docs.
📜 Review details
Configuration used: CodeRabbit UI
Review profile: CHILL
Plan: Pro
📒 Files selected for processing (14)
examples/speculative_decoding/README.md
(2 hunks)examples/speculative_decoding/SLURM_prepare_data.md
(1 hunks)examples/speculative_decoding/ar_validate.py
(2 hunks)examples/speculative_decoding/calibrate_draft_vocab.py
(2 hunks)examples/speculative_decoding/eagle_config.json
(1 hunks)examples/speculative_decoding/export_hf_checkpoint.py
(1 hunks)examples/speculative_decoding/launch_train.sh
(1 hunks)examples/speculative_decoding/main.py
(3 hunks)examples/speculative_decoding/server_generate.py
(1 hunks)examples/speculative_decoding/train_eagle3_and_export.sh
(1 hunks)modelopt/torch/export/plugins/__init__.py
(1 hunks)modelopt/torch/export/plugins/hf_spec_export.py
(1 hunks)modelopt/torch/export/unified_export_hf.py
(3 hunks)modelopt/torch/speculative/plugins/transformers.py
(11 hunks)
🚧 Files skipped from review as they are similar to previous changes (11)
- examples/speculative_decoding/eagle_config.json
- examples/speculative_decoding/export_hf_checkpoint.py
- examples/speculative_decoding/SLURM_prepare_data.md
- examples/speculative_decoding/main.py
- examples/speculative_decoding/ar_validate.py
- modelopt/torch/export/unified_export_hf.py
- modelopt/torch/export/plugins/init.py
- examples/speculative_decoding/launch_train.sh
- examples/speculative_decoding/train_eagle3_and_export.sh
- examples/speculative_decoding/calibrate_draft_vocab.py
- modelopt/torch/speculative/plugins/transformers.py
🧰 Additional context used
🧠 Learnings (3)
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, EAGLE1_DEFAULT_CFG in config.py already uses deepcopy(default_eagle_config) in its definition, making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/export/plugins/hf_spec_export.py
examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, both EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG in config.py already use deepcopied configurations (either directly via deepcopy() or through variables that were already deepcopied), making additional deepcopy calls unnecessary when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/export/plugins/hf_spec_export.py
examples/speculative_decoding/README.md
📚 Learning: 2025-09-05T19:10:36.359Z
Learnt from: yeyu-nvidia
PR: NVIDIA/TensorRT-Model-Optimizer#295
File: tests/gpu/torch/speculative/plugins/test_speculative_megatron_modules.py:35-39
Timestamp: 2025-09-05T19:10:36.359Z
Learning: In the TensorRT-Model-Optimizer codebase, the EAGLE1_DEFAULT_CFG and EAGLE3_DEFAULT_CFG configurations in config.py are already deepcopied, so additional deepcopy calls are not needed when using ALGO_TO_CONFIG mapping in tests.
Applied to files:
modelopt/torch/export/plugins/hf_spec_export.py
examples/speculative_decoding/README.md
⏰ Context from checks skipped due to timeout of 90000ms. You can increase the timeout in your CodeRabbit configuration to a maximum of 15 minutes (900000ms). (4)
- GitHub Check: linux
- GitHub Check: wait-checks / wait
- GitHub Check: build-docs
- GitHub Check: code-quality
/ok to test 854564e
Signed-off-by: h-guo18 <[email protected]>
What does this PR do?
Type of change: New feature, New example;
Overview:
Add necessary changes and new features for updating eagle example, including:
Feat: add export support for speculative decoding models in hf_unified_export;
Update several arguments names and default values for example simplicity;
Added a few new scripts and renamed some files for the example;
Rewrote the README:
Rearrange content order; Introduce a "simplified workflow" section;
Provided more details in the "Complete workflow" section.
Removed deprecated contents: Nemo link and notebook example;
Usage
See README.md for usage.
# Add a code snippet demonstrating how to use this
Testing
Tested dummy training + ar_validate + export with:
Tested deployment on:
Before your PR is "Ready for review"
Additional Information
Summary by CodeRabbit
New Features
Documentation
Changes
Bug Fixes